None
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
# NOTE(review): chdir assumes the notebook lives three directory levels below
# a sibling ``notebook_format`` folder — confirm repository layout
os.chdir(os.path.join('..', '..', '..', 'notebook_format'))
from formats import load_style
load_style(css_style='custom2.css', plot_style=False)

# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'
import os
import cv2
import timm
import random
import torch
import transformers
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import albumentations as A
import pytorch_lightning as pl
from typing import Optional
from dataclasses import dataclass
from time import perf_counter
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.utilities.model_summary import ModelSummary
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer
)
# silence huggingface tokenizer advisory warnings (e.g. the "using the
# __call__ method is faster" warning triggered by ``tokenizer.pad``)
# https://discuss.huggingface.co/t/get-using-the-call-method-is-faster-warning-with-datacollatorwithpadding/23924
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

# print author, last-updated date, python and package versions for reproducibility
%watermark -a 'Ethen' -d -u -v -iv
Self-supervision, a.k.a. pre-training methods, has been all the rage lately. The core idea behind this is that supervised learning, which has been the main workhorse in machine learning applications, requires labeled data. In real world scenarios, getting large amounts of labeled data can be very expensive, not to mention we might need to continuously annotate ground truth for new data to ensure our system can adapt to newer information. One of the benefits of self-supervised learning is to reduce the amount of labeling required. Given large amounts of un-labeled data at hand, we would create proxy tasks from the data itself and pre-train our model on them; these tasks essentially turn our un-supervised learning problem into a supervised one. By warming up our models, the hope is that we can then achieve competitive results on the downstream applications that we are actually interested in by fine-tuning on a smaller set of labeled data.
In this document, we'll go over one of the pre-training methods spanning the computer vision and natural language processing fields called CLIP (Contrastive Language-Image Pre-training) [6] [7]. The de-facto approach to a lot of vision tasks in this deep learning era is to start from pretrained visual representations, potentially trained via supervised learning on the ImageNet dataset. CLIP demonstrated that a pre-training task of predicting which caption goes with which image via contrastive loss is an efficient and scalable way to learn SOTA image representations. Directly quoting from the original paper:
The model transfers non-trivially to most tasks and is often competitive with a fully supervised baseline without the need for any dataset specific training. For instance, we match the accuracy of the original ResNet-50 on ImageNet zero-shot without needing to use any of the 1.28 million training examples it was trained on.
Personally, one of the most interesting things about this method is its multi-modality nature. i.e. there are pure text pre-training methods such as BERT or pure image pre-training methods such as SIMCLR, but using natural language supervision to guide image representation learning is definitely quite unique.
@dataclass
class Config:
    """Hyper parameters and file paths for the CLIP training pipeline."""

    # data locations (the flickr30k archive unzips into a nested directory)
    cache_dir: str = "./clip"
    base_dir: str = os.path.join(cache_dir, "flickr30k_images")
    image_dir: str = os.path.join(base_dir, "flickr30k_images")
    captions_path: str = os.path.join(base_dir, "results.csv")

    # split / data-loader settings
    val_size: float = 0.1
    batch_size: int = 64
    num_workers: int = 4
    seed: int = 42

    # optimization
    lr: float = 0.0001
    weight_decay: float = 0.0005
    temperature: float = 1.0  # divides the similarity logits in the contrastive loss
    epochs: int = 4
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_checkpoint: Optional[str] = None  # when set, load this checkpoint instead of training

    # image encoder (2048 is the pooled feature size of a timm resnet50)
    image_encoder_model: str = 'resnet50'
    image_embedding_dim: int = 2048
    image_size: int = 224

    # text encoder (768 is distilbert's hidden size)
    text_encoder_model: str = "distilbert-base-uncased"
    text_embedding_dim: int = 768
    text_max_length: int = 200

    # for both image encoder and text encoder
    pretrained: bool = True
    trainable: bool = True

    # projection head
    projection_dim: int = 256
    dropout: float = 0.1

    def __post_init__(self):
        # seed every RNG we rely on so runs are reproducible
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)


config = Config()
We'll be using publicly available Flickr30k from Kaggle as our example dataset, other dataset choices with reasonable sizes are Flickr8k and MS-COCO Captions.
# sample command to download the dataset using kaggle API and unzip it
kaggle datasets download -d hsankesara/flickr-image-dataset -p .
unzip flickr-image-dataset.zip -d .
The next few code chunks perform the usual steps of reading in our dataset and creating a train/validation split, then read in some samples for manual inspection. One important thing to note about this dataset is that each image is paired with multiple captions, 5 to be exact.
def create_train_val_df(captions_path, val_size):
df = pd.read_csv(
captions_path,
sep="|",
skiprows=1,
names=["image", "caption_number", "caption"]
)
# remove extra white space up front
df['caption'] = df['caption'].str.lstrip()
df['caption_number'] = df['caption_number'].str.lstrip()
# one of the rows is corrupted
df.loc[19999, 'caption_number'] = "4"
df.loc[19999, 'caption'] = "A dog runs across the grass ."
max_id = df.shape[0] // 5
df["ids"] = [id_ for id_ in range(max_id) for _ in range(5)]
image_ids = np.arange(0, max_id)
val_ids = np.random.choice(
image_ids,
size=int(val_size * len(image_ids)),
replace=False
)
train_ids = [id_ for id_ in image_ids if id_ not in val_ids]
df_train = df[df["ids"].isin(train_ids)].reset_index(drop=True)
df_val = df[df["ids"].isin(val_ids)].reset_index(drop=True)
return df_train, df_val
# build the splits and peek at their sizes plus the first few rows
df_train, df_val = create_train_val_df(config.captions_path, config.val_size)
print("train shape: ", df_train.shape)
print("val shape: ", df_val.shape)
df_train.head()

# show the first training image alongside its caption; opencv reads images in
# BGR order, so convert to RGB before plotting with matplotlib
image_file = df_train["image"].iloc[0]
caption = df_train["caption"].iloc[0]
image = cv2.imread(f"{config.image_dir}/{image_file}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
print(caption)
plt.imshow(image)
plt.show()
For our dataset, we'll keep it minimalistic and use:
class CLIPDataset(Dataset):
    """Pairs each caption with its image, tokenizing text and transforming images.

    image_filenames and captions must have the same length; so, if there are
    multiple captions for each image, the image_filenames must have repetitive
    file names.
    """

    def __init__(self, image_dir, image_filenames, image_size, captions, tokenizer, text_max_length):
        self.image_dir = image_dir
        self.image_filenames = image_filenames
        self.captions = captions
        self.tokenizer = tokenizer
        self.text_max_length = text_max_length
        self.transforms = self.get_transforms(image_size)

    def __getitem__(self, idx):
        # tokenize the caption; padding is deferred to the collator so each
        # batch is only padded up to its own longest sequence
        encoded = self.tokenizer(
            self.captions[idx],
            truncation=True,
            max_length=self.text_max_length
        )
        # opencv reads BGR; convert to RGB before resize/normalize
        image = cv2.imread(f"{self.image_dir}/{self.image_filenames[idx]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)["image"]
        encoded["image"] = image
        return encoded

    def __len__(self):
        return len(self.captions)

    def get_transforms(self, image_size):
        # bug fix: previously read the global ``config.image_size`` instead of
        # the ``image_size`` argument, silently ignoring the caller's value
        return A.Compose(
            [
                A.Resize(image_size, image_size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True)
            ]
        )
class DataCollatorForClip:
    """Data collator that will dynamically pad the text and image inputs received."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features):
        # pad the variable-length token sequences up to the batch maximum
        batch = self.tokenizer.pad(
            {
                "input_ids": [example["input_ids"] for example in features],
                "attention_mask": [example["attention_mask"] for example in features]
            },
            padding=True,
            return_tensors="pt"
        )
        # stack images and convert from channel-last (NHWC) to channel-first (NCHW)
        stacked = torch.stack([torch.FloatTensor(example["image"]) for example in features])
        batch["image"] = stacked.permute(0, 3, 1, 2)
        return batch
def build_data_loaders(config, df, tokenizer, mode):
    """Create a DataLoader for one dataframe split ("train" / "val" / "test")."""
    is_train = mode == "train"
    dataset = CLIPDataset(
        config.image_dir,
        list(df["image"].values),
        config.image_size,
        list(df["caption"].values),
        tokenizer,
        config.text_max_length,
    )
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        collate_fn=DataCollatorForClip(tokenizer),
        pin_memory=True,
        # only the training split needs its sample order randomized
        shuffle=is_train
    )
# sanity check: build the loaders and inspect the shapes of one padded batch
tokenizer = AutoTokenizer.from_pretrained(config.text_encoder_model)
data_loader_train = build_data_loaders(config, df_train, tokenizer, mode="train")
data_loader_val = build_data_loaders(config, df_val, tokenizer, mode="val")

batch = next(iter(data_loader_train))
print("text shape: ", batch["input_ids"].shape)
print("image shape: ", batch["image"].shape)
CLIP model comprises of three components [6]: image encoder, text encoder and projection head (absorbed inside encoder block in the diagram).
During training we'll need to feed our batches of text and images through its own respective encoder. Given a batch of $n$ text, image pairs, $\text{text}_i, \text{image}_i$, CLIP is trained to predict which of the $n \times n$ possible (image, text) pairings across a batch actually occurred. To do this, CLIP learns a multi-modal embedding space by jointly training an image encoder and text encoder to maximize the $n$ real image and text embedding pairs' (cosine) similarity in a given batch while minimizing cosine similarity of $n^2 - n$ incorrect image and text embedding pairings. This is commonly referred to as InfoNCE loss.
\begin{align} L = -\frac{1}{n} \sum^n_{i=1} \log \frac{\exp(sim(\text{text}_i, \text{image}_i))}{\sum_j \exp(sim(\text{text}_i, \text{image}_j))} \end{align}Where $sim$ is a similarity function such as cosine similarity or dot product that aims to pull matching text and image pairs closer together while pushing mismatched pairs apart. This can then be treated as a multi-class classification task using cross entropy loss, with the difference that the number of classes here will be the number of negatives within that batch instead of the distinct number of classes that we wish to predict in a multi-class classification problem.
Here we will be implementing the clip components ourselves, but huggingface transformers also has a clip implementation which we can use for reference.
We'll use an image encoder from timm and a text encoder from huggingface transformers. Feel free to experiment with different image or text encoders, as well as the projection head.
class ProjectionHead(nn.Module):
    """Project encoder features into the shared multi-modal embedding space.

    A linear projection is refined by a gelu / linear / dropout branch, added
    back to the projection as a residual, then layer-normalized.
    """

    def __init__(self, embedding_dim, projection_dim, dropout):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        # residual connection back to the raw projection
        shortcut = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(shortcut)))
        return self.layer_norm(refined + shortcut)
class ImageEncoder(nn.Module):
    """Encode raw image batches into projected embeddings via a timm backbone."""

    def __init__(
        self, model_name, pretrained, trainable, embedding_dim, projection_dim, dropout
    ):
        super().__init__()
        # headless backbone: num_classes=0 plus average global pooling yields
        # a flat feature vector per image instead of classification logits
        self.model = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg"
        )
        # freeze or unfreeze the entire backbone in one pass
        for param in self.model.parameters():
            param.requires_grad = trainable
        self.projection = ProjectionHead(embedding_dim, projection_dim, dropout)

    def forward(self, img_arr):
        return self.projection(self.model(img_arr))
# smoke test: one forward pass through the image encoder on the sample batch;
# output should be (batch_size, projection_dim)
image_encoder = ImageEncoder(
    model_name=config.image_encoder_model,
    pretrained=config.pretrained,
    trainable=config.trainable,
    embedding_dim=config.image_embedding_dim,
    projection_dim=config.projection_dim,
    dropout=config.dropout
)
image_embedding = image_encoder(batch["image"])
image_embedding.shape
class TextEncoder(nn.Module):
    """Encode tokenized text into projected embeddings via a huggingface model."""

    def __init__(
        self, model_name, pretrained, trainable, embedding_dim, projection_dim, dropout
    ):
        super().__init__()
        if pretrained:
            self.model = AutoModel.from_pretrained(model_name)
        else:
            # architecture only, with randomly initialized weights
            model_config = AutoConfig.from_pretrained(model_name)
            self.model = AutoModel.from_config(model_config)

        # freeze or unfreeze the entire transformer in one pass
        for param in self.model.parameters():
            param.requires_grad = trainable

        self.projection = ProjectionHead(embedding_dim, projection_dim, dropout)

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # we are using CLS token hidden representation as sentence's embedding
        cls_hidden = output.last_hidden_state[:, 0, :]
        return self.projection(cls_hidden)
# smoke test: one forward pass through the text encoder on the sample batch;
# output shape should match the image embedding's (batch_size, projection_dim)
text_encoder = TextEncoder(
    model_name=config.text_encoder_model,
    pretrained=config.pretrained,
    trainable=config.trainable,
    embedding_dim=config.text_embedding_dim,
    projection_dim=config.projection_dim,
    dropout=config.dropout
)
text_embedding = text_encoder(batch["input_ids"], batch["attention_mask"])
text_embedding.shape
Upon confirming both our image and text encoder returns an embedding of the same shape, we can now assemble the clip model.
class ClipModel(pl.LightningModule):
    """CLIP-style dual encoder trained with a symmetric contrastive (InfoNCE) loss.

    Jointly trains an image tower and a text tower so that matching
    (image, caption) pairs in a batch end up with high similarity while all
    other pairings in the batch act as negatives.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_encoder = ImageEncoder(
            model_name=config.image_encoder_model,
            pretrained=config.pretrained,
            trainable=config.trainable,
            embedding_dim=config.image_embedding_dim,
            projection_dim=config.projection_dim,
            dropout=config.dropout
        )
        self.text_encoder = TextEncoder(
            model_name=config.text_encoder_model,
            pretrained=config.pretrained,
            trainable=config.trainable,
            embedding_dim=config.text_embedding_dim,
            projection_dim=config.projection_dim,
            dropout=config.dropout
        )
        self.save_hyperparameters()

    def forward(self, batch):
        # Getting text and image embedding (with same dimension)
        text_embedding = self.text_encoder(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"]
        )
        image_embedding = self.image_encoder(batch["image"])
        return text_embedding, image_embedding

    def _compute_loss(self, text_embedding, image_embedding):
        # @ is equivalent to torch.mm, calculating pairwise similarity between
        # every text and image embedding in the batch; the matching pairs live
        # on the diagonal, hence target class i for row i
        logits = (text_embedding @ image_embedding.T) / self.config.temperature
        # bug fix: the original read the notebook's module-level ``config``
        # global (not ``self.config``); deriving the device from the logits
        # keeps the targets on the right device regardless of how the module
        # was moved or which accelerator Lightning picked
        targets = torch.arange(logits.shape[0], device=logits.device, dtype=torch.long)
        # symmetric cross entropy: text -> image plus image -> text, averaged
        text_loss = F.cross_entropy(logits, targets, reduction="none")
        image_loss = F.cross_entropy(logits.T, targets, reduction="none")
        loss = (image_loss + text_loss) / 2.0
        return loss.mean()

    def training_step(self, batch, batch_idx):
        text_embedding, image_embedding = self(batch)
        loss = self._compute_loss(text_embedding, image_embedding)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        text_embedding, image_embedding = self(batch)
        loss = self._compute_loss(text_embedding, image_embedding)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.config.lr,
            weight_decay=self.config.weight_decay
        )
# train from scratch unless a checkpoint path is configured
clip = ClipModel(config)
if not config.model_checkpoint:
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=-1,
        max_epochs=config.epochs,
        precision=16,
        # note, we purpose-fully disabled the progress bar to prevent flooding our notebook's console
        # in normal settings, we can/should definitely turn it on
        enable_progress_bar=False,
        log_every_n_steps=500,
    )
    t1_start = perf_counter()
    trainer.fit(clip, data_loader_train, data_loader_val)
    t1_stop = perf_counter()
    print("Training elapsed time:", t1_stop - t1_start)
    # load_from_checkpoint is a classmethod: it returns a NEW restored instance
    # (here from the best checkpoint) rather than mutating ``clip`` in place
    clip = clip.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
else:
    clip = clip.load_from_checkpoint(config.model_checkpoint)

# set to evaluation after training, we'll be doing inferencing in the next section
clip.to(config.device)
clip.eval()
ModelSummary(clip, max_depth=1)
For inferencing, what we'll do is giving the model a piece of text and have it retrieve the most relevant images from an unseen validation (or test) set.
def get_image_embeddings(image_encoder, data_loader):
    """Encode every image batch from the loader and return one stacked tensor."""
    # pure inference: no gradient tracking needed
    with torch.no_grad():
        chunks = [
            image_encoder(batch["image"].to(config.device))
            for batch in data_loader
        ]
    return torch.cat(chunks)
# generate embeddings for all distinct images in our validation dataset
# (each image appears 5 times, once per caption; keeping only caption_number
# "0" de-duplicates down to one row per image)
df_test = df_val[df_val["caption_number"] == "0"]
data_loader_test = build_data_loaders(config, df_test, tokenizer, mode="test")

t1_start = perf_counter()
image_embeddings = get_image_embeddings(clip.image_encoder, data_loader_test)
t1_stop = perf_counter()
print("Image embedding prediction elapsed time:", t1_stop - t1_start)
def find_matches(model, image_embeddings, query, image_filenames, n):
    """
    This is a crude way of computing top-k cosine similarity between
    input query and all of our image embeddings.
    """
    # embed the free-text query with the trained text tower
    encoded_query = tokenizer([query])
    model_inputs = {
        key: torch.tensor(values).to(config.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_embeddings = model.text_encoder(
            input_ids=model_inputs["input_ids"],
            attention_mask=model_inputs["attention_mask"]
        )

    # L2-normalize both sides so the dot product equals cosine similarity
    normalized_images = F.normalize(image_embeddings, p=2, dim=-1)
    normalized_text = F.normalize(text_embeddings, p=2, dim=-1)
    similarity = normalized_text @ normalized_images.T
    _, top_indices = torch.topk(similarity.squeeze(0), n)
    matches = [image_filenames[index] for index in top_indices]

    # display the retrieved images on a 3 x 3 grid
    _, axes = plt.subplots(3, 3, figsize=(10, 10))
    for match, ax in zip(matches, axes.flatten()):
        retrieved = cv2.imread(f"{config.image_dir}/{match}")
        retrieved = cv2.cvtColor(retrieved, cv2.COLOR_BGR2RGB)
        ax.imshow(retrieved)
        ax.axis("off")
    plt.show()
# retrieve the top 9 validation images for a sample text query
find_matches(
    clip,
    image_embeddings,
    query="dogs on the grass",
    image_filenames=df_test['image'].values,
    n=9
)
Quick qualitative analysis shows our model is performing quite well on the image retrieval task; of course, we would like to perform quantitative evaluation on downstream benchmark tasks and let actual numbers speak for themselves. One common setup for doing this is called linear probing, where we add a linear classifier on top of our pre-trained embedding and fine-tune it on a downstream supervised learning task while keeping the pre-trained embedding frozen during training. Will leave this for future posts.
In this document we went through how we can use PyTorch Lightning to implement CLIP like models.
Some of other key learnings from the work includes: